package com.utils.legao;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
import java.util.function.ToDoubleFunction;

/**
 * Builds a photo mosaic: the target picture at {@code aimPath} is cut into a grid of tiles and each
 * tile is replaced by the closest-matching picture from the library directory {@code dbPath}. The
 * result is saved to {@code outPath}.
 *
 * <p>Matching uses the key produced by {@code ImageUtil.calKey} for the current mode (one of the
 * {@code Mode} constants). Each library picture is used at most {@code max} times while other
 * candidates still have quota left. Once every quota is spent, the closest picture is reused.
 *
 * <p>Not safe for concurrent {@link #make()} calls on the same instance: {@code make} rewrites the
 * tile-size and default-size fields.
 */
public class MosaicMaker {

    private static final Logger logger = LoggerFactory.getLogger(MosaicMaker.class);
    // How long core() waits for the tile workers before saving whatever has been drawn.
    private static final long TILE_TIMEOUT_SECONDS = 15;

    // Directory containing the tile library pictures
    private final String dbPath;
    // Path of the target picture
    private final String aimPath;
    // Output path of the generated mosaic
    private final String outPath;
    // Tile width in pixels (recomputed by calSubIm)
    private int subWidth;
    // Tile height in pixels (recomputed by calSubIm)
    private int subHeight;
    // Granularity multiplier applied to the tile size
    private final int unitMin;
    // Matching mode, one of the Mode constants
    private String mode;
    // Fallback output width used when calSubIm rejects the target size (calSubIm may overwrite it)
    private int defaultW;
    // Fallback output height used when calSubIm rejects the target size (calSubIm may overwrite it)
    private int defaultH;
    // Maximum number of times each library picture should appear
    private int max;
    // Number of worker threads used to load the library and to fill tiles
    private final int threadNum;
    // Tile-size divisor: larger values give smaller tiles and a sharper result (range 100~500)
    private final int rate;
    // Whether to match tiles through the AVL tree instead of a linear scan of the map
    private boolean useTree = false;
    private final Map<String, ImageInfo> map = new ConcurrentHashMap<>();
    private final AVLTree<BufferedImage> tree = new AVLTree<>();

    /**
     * Creates a maker with default settings: 32x32 initial tiles, granularity 5, RGB matching,
     * 1920x1080 fallback size, each picture used at most 300 times, and 20 worker threads.
     *
     * @param dbPath  directory containing the tile library pictures
     * @param aimPath path of the target picture
     * @param outPath path the mosaic is written to
     * @param rate    tile-size divisor; larger values give a sharper result (range 100~500)
     */
    public MosaicMaker(String dbPath, String aimPath, String outPath, int rate) {
        this(dbPath, aimPath, outPath, 32, 32, 5, Mode.RGB, 1920, 1080, 300, 20, rate);
    }

    private MosaicMaker(String dbPath, String aimPath, String outPath, int subWidth, int subHeight, int unitMin, String mode, int defaultW, int defaultH, int max, int threadNum, int rate) {
        this.dbPath = dbPath;
        this.aimPath = aimPath;
        this.outPath = outPath;
        this.subWidth = subWidth;
        this.subHeight = subHeight;
        this.unitMin = unitMin;
        this.mode = mode;
        this.defaultW = defaultW;
        this.defaultH = defaultH;
        this.max = max;
        this.threadNum = threadNum;
        this.rate = rate;
    }

    /** Sets the matching mode; expected to be one of the {@code Mode} constants. */
    public void setMode(String mode) {
        this.mode = mode;
    }

    /** Returns the fallback output width (may have been overwritten by a previous {@link #make()}). */
    public int getDefaultW() {
        return defaultW;
    }

    /** Returns the fallback output height (may have been overwritten by a previous {@link #make()}). */
    public int getDefaultH() {
        return defaultH;
    }

    /**
     * Generates the mosaic and writes it to {@code outPath}.
     *
     * @throws IOException if the target picture cannot be read or decoded, the library directory
     *     cannot be listed, or saving the result fails
     */
    public void make() throws IOException {
        BufferedImage aimIm = ImageIO.read(new File(aimPath));
        if (aimIm == null) {
            // ImageIO.read returns null (instead of throwing) when no reader supports the format.
            throw new IOException("Unsupported or unreadable target image: " + aimPath);
        }
        int aimWidth = aimIm.getWidth();
        int aimHeight = aimIm.getHeight();
        if (!calSubIm(aimWidth, aimHeight)) {
            // Use the default output size.
            aimIm = ImageUtil.resize(aimIm, defaultW, defaultH);
        }
        readAllImage();
        core(aimIm);
    }

    /**
     * Fills a canvas the size of {@code aimIm} with library tiles and saves it to {@code outPath}.
     * Columns are filled in parallel. Any strip narrower than one tile at the right or bottom edge
     * is left blank.
     */
    private void core(BufferedImage aimIm) throws IOException {
        int width = aimIm.getWidth();
        int height = aimIm.getHeight();
        int cols = width / subWidth;
        int rows = height / subHeight;
        long start = System.currentTimeMillis();
        // The BufferedImage constructor rejects TYPE_CUSTOM, so fall back to plain RGB for such inputs.
        int type = aimIm.getType() == BufferedImage.TYPE_CUSTOM ? BufferedImage.TYPE_INT_RGB : aimIm.getType();
        BufferedImage newIm = new BufferedImage(width, height, type);
        Graphics2D g = newIm.createGraphics();
        ExecutorService pool = Executors.newFixedThreadPool(threadNum);
        try {
            CountDownLatch latch = new CountDownLatch(cols);
            for (int i = 0; i < cols; i++) {
                int x = i * subWidth;
                pool.execute(() -> {
                    try {
                        drawColumn(aimIm, g, x, rows);
                    } catch (RuntimeException e) {
                        // Best effort: a failing column is logged and left partially blank.
                        logger.warn("Failed to fill tile column at x={}", x, e);
                    } finally {
                        // Always count down so await() is never stuck on a crashed worker.
                        latch.countDown();
                    }
                });
            }
            if (!latch.await(TILE_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                logger.warn("Tile workers did not finish within {} seconds; saving a partial mosaic", TILE_TIMEOUT_SECONDS);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            logger.warn("Interrupted while filling tiles; saving a partial mosaic", e);
        } finally {
            pool.shutdownNow();
            g.dispose();
        }
        logger.info("拼图完成，耗时" + (System.currentTimeMillis() - start) + "毫秒");
        ImageUtil.save(newIm, outPath);
    }

    /** Replaces the {@code rows} tiles of the column starting at pixel {@code x}. */
    private void drawColumn(BufferedImage aimIm, Graphics2D g, int x, int rows) {
        for (int j = 0; j < rows; j++) {
            int y = j * subHeight;
            BufferedImage fitSubIm = findFitIm(aimIm.getSubimage(x, y, subWidth, subHeight));
            if (fitSubIm == null) {
                continue;
            }
            // Graphics2D is shared by all column workers and is not documented as thread-safe.
            synchronized (g) {
                g.drawImage(fitSubIm, x, y, subWidth, subHeight, null);
            }
        }
    }

    /** Finds the library picture that best matches {@code image}; returns null for an unknown mode. */
    private BufferedImage findFitIm(BufferedImage image) {
        switch (mode) {
            case Mode.RGB:
                if (useTree) {
                    return tree.getCloseByRGB(ImageUtil.calKey(image, mode));
                } else {
                    return findByRGB(image);
                }
            case Mode.GRAY:
                if (useTree) {
                    return tree.getCloseByGray(ImageUtil.calKey(image, mode));
                } else {
                    return findByGRAY(image);
                }
            case Mode.PHASH:
                return findByPHASH(image);
            default:
                return null;
        }
    }

    /** Matches by L1 distance between "r-g-b" keys. */
    private BufferedImage findByRGB(BufferedImage image) {
        String[] keys = ImageUtil.calKey(image, mode).split("-");
        float r = Float.parseFloat(keys[0]);
        float g = Float.parseFloat(keys[1]);
        float b = Float.parseFloat(keys[2]);
        return pickClosest(k -> {
            String[] mk = k.split("-");
            return Math.abs(Float.parseFloat(mk[0]) - r)
                    + Math.abs(Float.parseFloat(mk[1]) - g)
                    + Math.abs(Float.parseFloat(mk[2]) - b);
        });
    }

    /** Matches by absolute difference between single-number gray keys. */
    private BufferedImage findByGRAY(BufferedImage image) {
        float gray = Float.parseFloat(ImageUtil.calKey(image, mode));
        return pickClosest(k -> Math.abs(Float.parseFloat(k) - gray));
    }

    /** Matches by Hamming distance between perceptual-hash keys. */
    private BufferedImage findByPHASH(BufferedImage image) {
        String key = ImageUtil.calKey(image, mode);
        return pickClosest(k -> hammingDistance(key, k));
    }

    /**
     * Returns the library picture whose key minimises {@code distance}, preferring pictures that
     * still have quota ({@code max > 0}), and consumes one use of the chosen picture. When every
     * quota is spent, the overall closest picture is reused rather than failing. Ties go to the
     * first candidate seen.
     *
     * @return the chosen picture, or null if the library is empty
     */
    private BufferedImage pickClosest(ToDoubleFunction<String> distance) {
        ImageInfo best = null;
        double bestDif = 0;
        ImageInfo closest = null;
        double closestDif = 0;
        for (Map.Entry<String, ImageInfo> entry : map.entrySet()) {
            double dif = distance.applyAsDouble(entry.getKey());
            ImageInfo info = entry.getValue();
            if (closest == null || dif < closestDif) {
                closest = info;
                closestDif = dif;
            }
            if (info.max > 0 && (best == null || dif < bestDif)) {
                best = info;
                bestDif = dif;
            }
        }
        ImageInfo chosen = best != null ? best : closest;
        if (chosen == null) {
            return null;
        }
        // Several tile workers may pick the same picture. Serialize the decrement so no update is
        // lost. The quota check above is unsynchronized, so it can still overshoot slightly.
        synchronized (chosen) {
            chosen.max = chosen.max - 1;
        }
        return chosen.im;
    }

    /** Counts differing characters; any extra length in the longer string counts as differences. */
    private static int hammingDistance(String a, String b) {
        int common = Math.min(a.length(), b.length());
        int dif = Math.abs(a.length() - b.length());
        for (int i = 0; i < common; i++) {
            if (a.charAt(i) != b.charAt(i)) {
                dif++;
            }
        }
        return dif;
    }

    /**
     * Loads every file in {@code dbPath} in parallel, scaled to the current tile size, into the map
     * (or the tree when {@code useTree} is set). Unreadable or non-image files are skipped.
     *
     * @throws IOException if {@code dbPath} is not a listable directory
     */
    private void readAllImage() throws IOException {
        File[] files = new File(this.dbPath).listFiles();
        if (files == null) {
            throw new IOException("Image library is not a readable directory: " + dbPath);
        }
        long start = System.currentTimeMillis();
        // Create every task up front so the latch (sized threadNum) is always fully counted down,
        // even when there are fewer files than threads.
        ReadTask[] readTasks = new ReadTask[threadNum];
        CountDownLatch latch = new CountDownLatch(threadNum);
        for (int t = 0; t < threadNum; t++) {
            readTasks[t] = new ReadTask(latch, subWidth, subHeight);
        }
        int next = 0;
        for (File f : files) {
            if (f.isFile()) {
                readTasks[next++ % threadNum].add(f);
            }
        }
        ExecutorService pool = Executors.newFixedThreadPool(threadNum);
        try {
            for (ReadTask task : readTasks) {
                pool.execute(task);
            }
            latch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            logger.warn("Interrupted while reading the image library", e);
        } finally {
            pool.shutdown();
        }
        if (useTree) {
            logger.info("共读取" + tree.size() + "张图片");
        } else {
            logger.info("共读取" + map.size() + "张图片");
        }
        logger.info("读取图库完成，耗时" + (System.currentTimeMillis() - start) + "毫秒");
    }

    /** Loads a batch of library files, scaled to w x h, and always counts down its latch once. */
    private class ReadTask implements Runnable {

        private final CountDownLatch latch;
        private final List<File> files = new ArrayList<>();
        private final int w;
        private final int h;

        ReadTask(CountDownLatch latch, int w, int h) {
            this.latch = latch;
            this.w = w;
            this.h = h;
        }

        public void add(File file) {
            files.add(file);
        }

        @Override
        public void run() {
            try {
                for (File f : files) {
                    load(f);
                }
            } finally {
                // Must run even if a load fails, otherwise readAllImage() would wait forever.
                latch.countDown();
            }
        }

        /** Loads one file, logging and skipping it if it is unreadable or not an image. */
        private void load(File f) {
            try {
                BufferedImage raw = ImageIO.read(f);
                if (raw == null) {
                    // ImageIO.read returns null for formats it has no reader for (e.g. stray non-image files).
                    logger.warn("Skipping non-image file {}", f);
                    return;
                }
                BufferedImage im = ImageUtil.resize(raw, w, h);
                String key = ImageUtil.calKey(im, mode);
                if (useTree) {
                    // NOTE(review): AVLTree's thread-safety is not visible here; inserts are serialized to be safe.
                    synchronized (tree) {
                        tree.insert(key, im);
                    }
                } else {
                    map.put(key, new ImageInfo(max, im));
                }
            } catch (IOException | RuntimeException e) {
                logger.warn("Failed to load library image {}", f, e);
            }
        }
    }

    /**
     * Computes the tile size (subWidth x subHeight) for a target of w x h pixels.
     *
     * <p>The divisor is gcd(w, h), so tiles keep the image's aspect ratio. The gcd is replaced by
     * {@code rate} in two cases:
     * <ul>
     *   <li>it is below 20: the fallback size fields are set to w x h and square tiles based on
     *       the shorter side are used;</li>
     *   <li>the shorter side divides the longer one (including square images): otherwise the tiles
     *       would only be {@code unitMin} pixels on the short side.</li>
     * </ul>
     *
     * @return always true (the target is never rejected)
     */
    private boolean calSubIm(int w, int h) {
        int g = gcd(w, h);
        if (g < 20) {
            this.defaultH = h;
            this.defaultW = w;
            // Make the tiles square, using the shorter side for both dimensions.
            if (w < h) {
                h = w;
            } else {
                w = h;
            }
            g = rate;
        }
        if (g == Math.min(w, h)) {
            g = rate;
        }
        // At least one unit per side, so a target smaller than rate never yields a 0-pixel tile.
        subWidth = unitMin * Math.max(1, w / g);
        subHeight = unitMin * Math.max(1, h / g);
        return true;
    }

    /** Returns the greatest common divisor of two non-negative ints using Euclid's algorithm. */
    private int gcd(int a, int b) {
        while (b != 0) {
            int r = a % b;
            a = b;
            b = r;
        }
        return a;
    }
}
